{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SQL\n",
    "\n",
    "이번 튜토리얼은 `create_sql_query_chain` 을 활용하여 생성한 체인으로 SQL 쿼리를 생성 후 실행, 그리고 답변 도출하는 방법에 대해서 다룹니다. 그리고, SQL Agent 와의 동작 방식의 차이도 알아보겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# API KEY를 환경변수로 관리하기 위한 설정 파일\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "# API KEY 정보로드\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LangSmith 추적을 시작합니다.\n",
      "[프로젝트명]\n",
      "SQL\n"
     ]
    }
   ],
   "source": [
    "# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
    "# !pip install langchain-teddynote\n",
    "from langchain_teddynote import logging\n",
    "\n",
    "# 프로젝트 이름을 입력합니다.\n",
    "logging.langsmith(\"SQL\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SQL Database 정보를 불러옵니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sqlite\n",
      "['accounts', 'customers', 'transactions']\n"
     ]
    }
   ],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.chains import create_sql_query_chain\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# SQLite 데이터베이스에 연결합니다.\n",
    "db = SQLDatabase.from_uri(\"sqlite:///data/finance.db\")\n",
    "\n",
    "# 데이터베이스의을 출력합니다.\n",
    "print(db.dialect)\n",
    "\n",
    "# 사용 가능한 테이블 이름들을 출력합니다.\n",
    "print(db.get_usable_table_names())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LLM 객체를 생성하고 LLM 과 DB 를 매개변수로 입력하여 chain 을 생성합니다.\n",
    "\n",
    "여기서 model 변경시 원활하게 동작하지 않을 수 있어, 이번 튜토리얼은 `gpt-3.5-turbo` 로 진행합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model 은 gpt-3.5-turbo 를 지정\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "\n",
    "# LLM 과 DB 를 매개변수로 입력하여 chain 을 생성합니다.\n",
    "chain = create_sql_query_chain(llm, db)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(옵션) 아래의 방식으로 Prompt 를 직접 지정할 수 있습니다.\n",
    "\n",
    "직접 작성시 table_info 와 더불어 설명가능한 column description 을 추가할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "prompt = PromptTemplate.from_template(\n",
    "    \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
    "Use the following format:\n",
    "\n",
    "Question: \"Question here\"\n",
    "SQLQuery: \"SQL Query to run\"\n",
    "SQLResult: \"Result of the SQLQuery\"\n",
    "Answer: \"Final answer here\"\n",
    "\n",
    "Only use the following tables:\n",
    "{table_info}\n",
    "\n",
    "Here is the description of the columns in the tables:\n",
    "`cust`: customer name\n",
    "`prod`: product name\n",
    "`trans`: transaction date\n",
    "\n",
    "Question: {input}\"\"\"\n",
    ").partial(dialect=db.dialect)\n",
    "\n",
    "# model 은 gpt-3.5-turbo 를 지정\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "\n",
    "# LLM 과 DB 를 매개변수로 입력하여 chain 을 생성합니다.\n",
    "chain = create_sql_query_chain(llm, db, prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "chain 을 실행하면 DB 기반으로 쿼리를 생성합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'SELECT name\\nFROM customers'\n"
     ]
    }
   ],
   "source": [
    "# chain 을 실행하고 결과를 출력합니다.\n",
    "generated_sql_query = chain.invoke({\"question\": \"고객의 이름을 나열하세요\"})\n",
    "\n",
    "# 생성된 쿼리를 출력합니다.\n",
    "print(generated_sql_query.__repr__())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음은 생성한 쿼리가 맞게 동작하는지 확인해 볼 차례입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n",
    "\n",
    "# 생성한 쿼리를 실행하기 위한 도구를 생성합니다.\n",
    "execute_query = QuerySQLDataBaseTool(db=db)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[('테디',), ('폴',), ('셜리',), ('민수',), ('지영',), ('은정',)]\""
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "execute_query.invoke({\"query\": generated_sql_query})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool\n",
    "\n",
    "# 도구\n",
    "execute_query = QuerySQLDataBaseTool(db=db)\n",
    "\n",
    "# SQL 쿼리 생성 체인\n",
    "write_query = create_sql_query_chain(llm, db)\n",
    "\n",
    "# 생성한 쿼리를 실행하기 위한 체인을 생성합니다.\n",
    "chain = write_query | execute_query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"[('teddy@example.com',)]\""
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 실행 결과 확인\n",
    "chain.invoke({\"question\": \"테디의 이메일을 조회하세요\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 답변을 LLM 으로 증강-생성\n",
    "\n",
    "이전 단계에서 생성한 chain 을 사용하면 답변이 단답형 형식으로 출력됩니다. 이를 LCEL 문법의 체인으로 좀 더 자연스러운 답변을 받을 수 있도록 조정할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from operator import itemgetter\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "answer_prompt = PromptTemplate.from_template(\n",
    "    \"\"\"Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n",
    "\n",
    "Question: {question}\n",
    "SQL Query: {query}\n",
    "SQL Result: {result}\n",
    "Answer: \"\"\"\n",
    ")\n",
    "\n",
    "answer = answer_prompt | llm | StrOutputParser()\n",
    "\n",
    "# 생성한 쿼리를 실행하고 결과를 출력하기 위한 체인을 생성합니다.\n",
    "chain = (\n",
    "    RunnablePassthrough.assign(query=write_query).assign(\n",
    "        result=itemgetter(\"query\") | execute_query\n",
    "    )\n",
    "    | answer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 실행 결과 확인\n",
    "chain.invoke({\"question\": \"테디의 transaction 의 합계를 구하세요\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agent\n",
    "\n",
    "Agent를 활용하여 Sql 쿼리를 생성하고 실행 결과를 답변으로 출력이 가능합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "from langchain_community.agent_toolkits import create_sql_agent\n",
    "\n",
    "# model 은 gpt-3.5-turbo 를 지정\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
    "\n",
    "# SQLite 데이터베이스에 연결합니다.\n",
    "db = SQLDatabase.from_uri(\"sqlite:///data/finance.db\")\n",
    "\n",
    "# Agent 생성\n",
    "agent_executor = create_sql_agent(llm, db=db, agent_type=\"openai-tools\", verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 실행 결과 확인\n",
    "agent_executor.invoke(\n",
    "    {\"input\": \"테디와 셜리의 transaction 의 합계를 구하고 비교하세요\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-kr-lwwSZlnu-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
